import numpy as np
import torch

class CVRPEnv:
    def __init__(self, num_nodes, capacity, depot_idx=0):
        """
        Args:
            num_nodes (int): 客户点数量 (不包括仓库)。总节点数 = num_nodes + 1 (仓库)
            capacity (float): 车辆最大容量。
            depot_idx (int): 仓库在节点列表中的索引。
        """
        self.num_nodes = num_nodes
        self.capacity = capacity
        self.depot_idx = depot_idx
        self.coords = None # 节点坐标 (包括仓库)
        self.demands = None # 客户需求 (仓库需求为0)

        self.current_vehicle_capacity = None
        self.visited_mask = None # 标记已访问节点
        self.current_node = None # 当前车辆所在节点
        self.total_distance = 0.0
        self.num_vehicles_used = 1 # 初始使用一辆车

        self.problem_size = num_nodes + 1 # 总节点数，包括仓库

    def _euclidean_distance(self, p1, p2):
        """计算两点间的欧几里得距离。"""
        return np.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)

    def reset(self, coords, demands):
        """
        重置环境，开始一个新的 VRP 实例。
        Args:
            coords (np.ndarray): (num_nodes + 1, 2) 形状的坐标数组。
            demands (np.ndarray): (num_nodes + 1,) 形状的需求数组。
        """
        self.coords = coords
        self.demands = demands

        self.current_vehicle_capacity = self.capacity
        self.visited_mask = np.zeros(self.problem_size, dtype=bool)
        self.current_node = self.depot_idx # 从仓库开始
        self.total_distance = 0.0
        self.num_vehicles_used = 1 # 初始使用一辆车

        # 仓库被视为已访问，但不能占用容量或需求
        self.visited_mask[self.depot_idx] = True

        # 返回初始状态 (当前节点, 访问掩码, 剩余容量, 当前距离, 已使用的车辆数)
        state = {
            'current_node': self.current_node,
            'visited_mask': self.visited_mask.copy(),
            'current_capacity': self.current_vehicle_capacity,
            'total_distance': self.total_distance,
            'num_vehicles_used': self.num_vehicles_used,
            'coords': self.coords,
            'demands': self.demands
        }
        return state

    def step(self, action):
        """
        执行一步动作。
        Args:
            action (int): 下一个要访问的节点索引。
        Returns:
            next_state (dict): 新的状态。
            reward (float): 奖励。
            done (bool): 是否结束。
            info (dict): 额外信息。
        """
        reward = 0.0
        done = False
        info = {}

        # 计算移动距离
        dist = self._euclidean_distance(self.coords[self.current_node], self.coords[action])
        self.total_distance += dist

        # 更新当前节点
        self.current_node = action

        # 处理动作：
        if action == self.depot_idx: # 回到仓库
            # 只有当所有客户都被服务或当前车辆无法服务任何剩余客户时才允许回到仓库
            if np.all(self.visited_mask[1:]): # 如果所有客户都已访问
                done = True # 结束一个 episode
            else: # 如果还有未访问客户，则启动新车辆
                self.num_vehicles_used += 1
                self.current_vehicle_capacity = self.capacity # 新车满载
                # 新车从仓库出发，所以不需要更新 current_node
        else: # 访问客户
            demand = self.demands[action]
            if self.current_vehicle_capacity >= demand and not self.visited_mask[action]:
                self.current_vehicle_capacity -= demand
                self.visited_mask[action] = True
            else:
                # 不可行动作：如果动作不合法，给予巨大惩罚
                reward = -1e9 # 巨大的惩罚
                done = True # 提前终止此 episode，表示一个失败的尝试
                info['violation'] = True
                print(f"Violation: capacity {self.current_vehicle_capacity} < demand {demand} or already visited {action}")

        # 检查是否所有客户都已访问
        if np.all(self.visited_mask[1:]): # 排除仓库
            # 确保最后回到仓库以完成所有任务
            if self.current_node != self.depot_idx:
                self.total_distance += self._euclidean_distance(self.coords[self.current_node], self.coords[self.depot_idx])
            self.current_node = self.depot_idx # 确保结束在仓库
            done = True


        # 奖励计算：
        # 这里为了简化，每次 step 的奖励可以为负的距离，或在 done 时给出总奖励
        if done:
            # 目标是最小化总距离和车辆数量
            # 奖励 = - (总距离 + alpha * 额外车辆数)
            reward = - (self.total_distance + (self.num_vehicles_used - 1) * 100) # 假设每辆额外车辆惩罚100单位
            if 'violation' in info: # 如果有违规，则惩罚更多
                reward -= 10000
        else:
            reward = 0 # 每次步进没有即时奖励，只在完成时

        next_state = {
            'current_node': self.current_node,
            'visited_mask': self.visited_mask.copy(),
            'current_capacity': self.current_vehicle_capacity,
            'total_distance': self.total_distance,
            'num_vehicles_used': self.num_vehicles_used,
            'coords': self.coords,
            'demands': self.demands
        }
        return next_state, reward, done, info

    def get_action_mask(self):
        """
        返回一个布尔掩码，指示哪些节点是当前可行的动作。
        - 已访问的客户不可选
        - 容量不足的客户不可选
        - 如果所有未访问客户都不可达（容量不足），则只能回仓库
        """
        mask = np.zeros(self.problem_size, dtype=bool)
        # 初始：所有未访问客户点都可以考虑
        mask[~self.visited_mask] = True

        # 容量过滤：移除当前车辆无法满足需求的客户
        for i in range(self.problem_size):
            if mask[i] and self.demands[i] > self.current_vehicle_capacity:
                mask[i] = False

        # 检查是否必须回仓库
        if np.all(~mask[1:]): # 如果除了仓库之外，所有未访问客户都不可达（被访问或容量不足）
            mask[:] = False # 清除所有
            mask[self.depot_idx] = True # 只能回仓库
        
        # 确保已访问的客户不能再次被访问 (除了仓库可以多次访问)
        mask[self.visited_mask & (np.arange(self.problem_size) != self.depot_idx)] = False

        return mask
